# ------------------------------------------------------------------------------
# Generates model predictions in the validation sets
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 05_predict_ensemble_components.sh {split} {restriction}
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(glmnet)
library(xgboost)
library(optparse)
library(testit) # assert
library(here) # here() relative filepaths

u <- modules::use(here::here("lib", "util.R"))

# Command Line Args ------------------------------------------------------------
# Declare and parse the two command-line options this script understands.
arg_config <- list(
  make_option("--split", type = "character"),
  make_option("--restriction", type = "character")
)
arg_parser <- OptionParser(option_list = arg_config)
arg_list <- parse_args(arg_parser)

# Both options feed every directory path built later in the script. Fail here
# with an explicit message instead of letting a missing value surface later as
# an unrelated-looking error.
missing_args <- Filter(
  function(nm) is.null(arg_list[[nm]]),
  c("split", "restriction")
)
if (length(missing_args) > 0) {
  stop(
    "Missing required argument(s): ",
    paste0("--", missing_args, collapse = ", "),
    call. = FALSE
  )
}

# NOTE(review): this masks base::split() for the rest of the script and is not
# referenced again in this file; kept in case other code expects it.
split <- arg_list$split

# Helpers ----------------------------------------------------------------------
# Build the probability-scale ("p__") score column name for model file(s).
#
# model_file: character vector of model file names.
# ...:        ignored; lets callers pass extra positional arguments (the
#             script passes the fold index).
#
# Returns a glue/character vector, one name per element of `model_file`:
# "p__" + the file name without its fold suffix + "__<restriction>" (the
# restriction label is omitted when --restriction is "all").
# Reads the script-level `arg_list` for the restriction.
get_score_name <- function(model_file, ...) {
  # Escape the dot and anchor at the end so only a literal "__<fold>.rds"
  # suffix is stripped (an unescaped "." would match any character).
  clean <- str_remove(model_file, "__[12345]\\.rds$")
  # The restriction is a single CLI value, so a scalar if/else is the right
  # tool; ifelse() is meant for vectors.
  restriction_lab <- if (arg_list$restriction == "all") {
    ""
  } else {
    glue("__{arg_list$restriction}")
  }
  glue("p__{clean}{restriction_lab}")
}

# Build the link-scale ("z__") score column name for model file(s).
#
# model_file: character vector of model file names.
# ...:        ignored; lets callers pass extra positional arguments (the
#             script passes the fold index).
#
# Returns a glue/character vector, one name per element of `model_file`:
# "z__" + the file name without its fold suffix + "__<restriction>" (the
# restriction label is omitted when --restriction is "all").
# Reads the script-level `arg_list` for the restriction.
get_score_name_log <- function(model_file, ...) {
  # Escape the dot and anchor at the end so only a literal "__<fold>.rds"
  # suffix is stripped (an unescaped "." would match any character).
  clean <- str_remove(model_file, "__[12345]\\.rds$")
  # The restriction is a single CLI value, so a scalar if/else is the right
  # tool; ifelse() is meant for vectors.
  restriction_lab <- if (arg_list$restriction == "all") {
    ""
  } else {
    glue("__{arg_list$restriction}")
  }
  glue("z__{clean}{restriction_lab}")
}

# Directories ------------------------------------------------------------------
message("Establishing directories...")
# Project paths come from lib/filepaths.yml; everything is keyed on the
# --split argument, and models/subscores additionally on --restriction.
paths <- read_yaml(here("lib", "filepaths.yml"))
cohort_dir <- file.path(paths$modeling$oos, "cohorts", arg_list$split)
features_dir <- file.path(paths$features$dir, arg_list$split)
models_dir <- file.path(
  paths$modeling$oos, "models", arg_list$split, arg_list$restriction
)
subscores_dir <- file.path(
  paths$modeling$oos, "subscores", arg_list$split, arg_list$restriction
)

# Check the input directories as well as the output one. In particular,
# list.files() on a missing models directory returns an empty vector without
# complaint, which would hide the real problem until much later.
assert("Cohort directory exists", dir.exists(cohort_dir))
assert("Features directory exists", dir.exists(features_dir))
assert("Models directory exists", dir.exists(models_dir))
assert("Subscores directory exists", dir.exists(subscores_dir))

# NOTE(review): no --dataset option is declared in the argument parser, so this
# looks like it is always NULL; it is not referenced again in this file.
dataset <- arg_list$dataset

# NOTE(review): not referenced again in this file.
build_models_dir <- here::here("code", "03_analysis", "01_build_models")

# Locate Models ----------------------------------------------------------------
message("Locating models...")
# Everything in the models directory except files whose name contains
# "ensemble__".
model_files <- grep(
  "ensemble__", list.files(models_dir),
  value = TRUE, invert = TRUE
)
# Component models are told apart by a marker in the file name.
gbm_files <- grep("gbm__", model_files, value = TRUE)
lasso_files <- grep("lasso__", model_files, value = TRUE)

# Get Subscores ----------------------------------------------------------------
downsamples <- c("tested", "ds_mace", "ds_test", "non_downsampled")

# The output column layout does not depend on the fold, so build it once.
# Identifier / flag columns carried over from the cohort files:
keep_cols <- c(
  "ptid", "ed_enc_id", "train_fold", "in_ensemble",
  "excl_flag_c_int", "excl_flag_chronic", "excl_flag_death"
)
# One "p__" (response scale) and one "z__" (margin / link scale) column per
# model. The name helpers strip the fold suffix, and union() de-duplicates, so
# the same model across folds maps to a single column name.
score_names <- map(model_files, get_score_name) %>% unlist %>%
  union(map(model_files, get_score_name_log) %>% unlist)
subscore_colnames <- c(keep_cols, score_names)

for (i in 1:5) {
  message("Getting ensemble subscores for ensemble ", i, "...")

  # Model files for fold i. Anchor on the "__<fold>.rds" suffix (the same
  # convention the score-name helpers strip) so a "__<i>" appearing elsewhere
  # in a file name cannot pull in the wrong model.
  fold_pattern <- paste0("__", i, "\\.rds$")
  lasso_files_fold <- str_subset(lasso_files, fold_pattern)
  gbm_files_fold <- str_subset(gbm_files, fold_pattern)
  if (length(lasso_files_fold) + length(gbm_files_fold) == 0) {
    stop("No model files found for fold ", i, " in ", models_dir, call. = FALSE)
  }

  # Empty frame with the full column layout; each downsample's rows are
  # appended to it below.
  subscores <- data.frame(matrix(ncol = length(subscore_colnames), nrow = 0))
  colnames(subscores) <- subscore_colnames
  # don't want to repeat IDs occurring in multiple downsamples
  used_ids <- c()
  for (ds in downsamples) {
    message("Loading data for ", ds, "...")
    ids <- readRDS(file.path(cohort_dir, glue("train_cohort_{ds}.rds")))
    x <- readRDS(file.path(features_dir, glue("train_features_{ds}.rds")))
    # Rows of x are selected below using row positions computed from ids, so
    # the two must line up one-to-one.
    assert(
      "Cohort and feature files have the same number of rows",
      nrow(x) == nrow(ids)
    )

    # Restrict the feature columns according to --restriction. Any value not
    # listed here (e.g. "all") keeps every feature.
    feat_names <- colnames(x)
    keep_feats <- switch(
      arg_list$restriction,
      # drop chief complaint features
      dropcc = which(!grepl("ed_enc_t0d", feat_names)),
      # keep only the features dropped by "dropcc"
      justcc = which(grepl("ed_enc_t0d", feat_names)),
      dem = which(grepl("dem_", feat_names)),
      enc = which(grepl("enc_", feat_names) & !grepl("_cc_", feat_names)),
      NULL
    )
    if (!is.null(keep_feats)) {
      # drop = FALSE keeps x two-dimensional even if a single column survives
      x <- x[, keep_feats, drop = FALSE]
    }

    # Rows scored for fold i: rows in train fold i plus rows flagged
    # in_ensemble, skipping encounters already scored from an earlier
    # downsample.
    keep <- which(
      (ids$train_fold == i | ids$in_ensemble) & !(ids$ed_enc_id %in% used_ids)
    )
    ids <- ids[keep, ] %>%
      select(all_of(keep_cols)) %>%
      setDT()
    # drop = FALSE keeps x two-dimensional even if a single row survives
    x <- x[keep, , drop = FALSE]
    used_ids <- c(used_ids, ids$ed_enc_id)

    # GBMs ---------------------------------------------------------------------
    message("Predicting with GBMS...")
    gbm_models <- map(
      file.path(models_dir, gbm_files_fold),
      xgb.load
    )
    # `:=` adds each score column to ids by reference, one per model.
    walk2(
      get_score_name(gbm_files_fold, i), gbm_models,
      ~ ids[, (.x) := predict(.y, x, outputmargin = FALSE)]
    )
    walk2(
      get_score_name_log(gbm_files_fold, i), gbm_models,
      ~ ids[, (.x) := predict(.y, x, outputmargin = TRUE)]
    )
    # Release the models before loading the next batch.
    rm(gbm_models)

    # LASSOs -------------------------------------------------------------------
    message("Predicting with LASSOs...")
    lasso_models <- map(
      file.path(models_dir, lasso_files_fold),
      readRDS
    )
    walk2(
      get_score_name(lasso_files_fold, i), lasso_models,
      ~ ids[, (.x) := predict(.y, x, s = "lambda.min", type = "response")]
    )
    walk2(
      get_score_name_log(lasso_files_fold, i), lasso_models,
      ~ ids[, (.x) := predict(.y, x, s = "lambda.min", type = "link")]
    )

    rm(lasso_models)

    subscores <- rbind(subscores, ids)
  }
  message("Saving subscores for fold ", i, "...")
  write_rds(subscores, file.path(subscores_dir, glue("subscores__{i}.rds")))
}

message("Done.")
